import json
import requests
import time
import html
import datasets
import argparse
from functools import partial


def str_to_list(input_string):
    """Parse an ``--out_variables`` command-line value into a list of field names.

    Two formats are accepted:

    * a JSON array, e.g. ``'["titleDesc", "cpv"]'`` -- returned exactly as
      ``json.loads`` decodes it;
    * a plain comma-separated list, e.g. ``titleDesc, cpv`` or
      ``[titleDesc, 'cpv']`` -- split on commas, with surrounding whitespace,
      quotes and an enclosing pair of brackets removed; empty items are dropped.

    Anything that is not a JSON array (invalid JSON, or a JSON scalar such as
    ``'"context"'``) goes through the comma-separated path, so the result is
    always a list.
    """
    try:
        parsed = json.loads(input_string)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None
    if isinstance(parsed, list):
        return parsed

    body = input_string.strip()
    if body.startswith("[") and body.endswith("]"):
        body = body[1:-1]
    names = (piece.strip().strip("'\"").strip() for piece in body.split(","))
    return [name for name in names if name]


def set_args(argv=None):
    """Define the command-line options of this script and parse them.

    Args:
        argv: list of argument strings to parse (program name excluded).
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, which
            is the previous behavior; passing a list lets tests or other
            scripts drive the parser directly.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--bot_id", default="1404001", help="Request bot_id")
    parser.add_argument("--input_file", default="annotation/20241117/idle_ai_sale_utt_level_annotation_20241117_sample2000.jsonl", help="Request input file")
    parser.add_argument(
        "--output_file", default="annotation/20241117/idle_ai_sale_utt_level_annotation_20241117_sample2000_response.jsonl", help="Request output file"
    )
    # Name of the input column holding each row's request parameters.
    parser.add_argument("--in_variables", default="promptVariables", help="Request in_variables")
    # Keys copied from that column into the bot request. argparse only runs
    # `type` on string values, so the list default is used as-is.
    parser.add_argument(
        "--out_variables", type=str_to_list, default=["titleDesc", "cpv", "qaPairs", "salePrice", "recentAgentOffer", "context"], help="Request out_variables"
    )
    parser.add_argument("--num_proc", default=30, type=int, help="Request num_proc")
    parser.add_argument("--breakpoint", action="store_true", help="Whether to perform breakpoint reconnection")
    parser.add_argument("--prompt_template", action="store_true", help="Whether to perform prompt template.")
    parser.add_argument("--template_key", default="customPromptTemplateId", help="Prompt template key")
    parser.add_argument("--template_value", default=54775, type=int, help="Prompt template value")

    parser.add_argument("--online", action="store_true", help="Whether to perform online.")

    return parser.parse_args(argv)


class IdleAiAgent:
    """Small HTTP client that executes one IdleAI bot via its ``/bot/execute`` endpoint."""

    def __init__(self, bot_id: str, online: bool = False, timeout=None) -> None:
        """
        Args:
            bot_id: value sent as ``botId`` in every request.
            online: when True, post to the ``idleai`` host; otherwise post to
                the ``pre-idleai`` host.
            timeout: optional per-request timeout (seconds) passed to
                ``requests.post``. ``None`` (the default) waits indefinitely,
                exactly as before.
        """
        self.bot_id = bot_id
        self.timeout = timeout
        if online:
            self.url = "http://idleai.alibaba-inc.com/bot/execute"
        else:
            self.url = "http://pre-idleai.alibaba-inc.com/bot/execute"

    def call(self, input_params: dict[str, object], retry_times: int = 5, retry_interval: float = 5) -> str:
        """Run the bot with ``input_params`` and return its HTML-unescaped reply.

        Non-string parameter values are sent JSON-encoded (``ensure_ascii=False``,
        ``indent=4``). The caller's dict is not modified.

        Args:
            input_params: request parameters, keyed by parameter name.
            retry_times: maximum number of attempts.
            retry_interval: seconds to wait between attempts (not after the last one).

        Returns:
            The ``data`` field of the first response whose ``success`` is truthy.
            If every attempt fails, the last failure (response body or exception)
            is printed and the sentinel string ``"error"`` is returned.
        """
        # Encode into a fresh dict instead of rewriting the caller's params in place.
        params = {
            key: value if isinstance(value, str) else json.dumps(value, ensure_ascii=False, indent=4)
            for key, value in input_params.items()
        }
        load = {"botId": self.bot_id, "params": params}
        last_failure = ""
        for attempt in range(retry_times):
            try:
                http_response = requests.post(
                    self.url,
                    json=load,
                    headers={"Content-Type": "application/json"},
                    timeout=self.timeout,
                )
                body = json.loads(http_response.text)
                if body["success"]:
                    return html.unescape(body["data"])
                last_failure = body
            except Exception as err:
                # Best effort on purpose: one bad request must not abort the whole
                # dataset run. The failure is recorded, then reported as "error" below.
                last_failure = err
            if attempt + 1 < retry_times:
                time.sleep(retry_interval)
        print(last_failure)
        return "error"


def _build_request_data(args, param):
    """Turn one ``args.in_variables`` cell into the parameter dict sent to the bot.

    The cell is JSON-decoded repeatedly while it is still a string, so cells
    that were JSON-encoded more than once are also accepted. The keys in
    ``args.out_variables`` are copied from the decoded object; missing keys
    become ``None``. With ``--prompt_template``, ``args.template_key`` is set
    to ``args.template_value``.

    Returns:
        The request dict. Returns ``None``, after printing the original cell,
        when the cell is not valid JSON or does not decode to an object with
        a ``.get`` method.
    """
    raw = param
    try:
        while isinstance(param, str):
            param = json.loads(param)
        request_data = {k: param.get(k) for k in args.out_variables}
    except (ValueError, AttributeError, TypeError):
        print(raw)
        return None
    if args.prompt_template:
        request_data[args.template_key] = args.template_value
    return request_data


def request_fun(args, agent, exams):
    """Query the bot once per row of a batch and store the replies in ``new_res``.

    Args:
        args: parsed command-line options (see ``set_args``).
        agent: object whose ``call(request_data)`` returns the reply string.
        exams: a batch as passed by ``datasets.map(batched=True)``, i.e. a
            mapping of column name -> list of values.

    Returns:
        ``exams`` with a ``new_res`` column holding exactly one entry per row.
        Rows whose input cannot be parsed get ``"error"``, like failed calls,
        so ``--breakpoint`` mode can retry them.
    """
    new_res = []
    for param in exams[args.in_variables]:
        request_data = _build_request_data(args, param)
        # Never skip a row. The input columns are returned unchanged, and every
        # column of the returned batch must have the same length.
        new_res.append("error" if request_data is None else agent.call(request_data))

    exams["new_res"] = new_res
    return exams


def request_checkpoint_fun(args, agent, exams):
    """Re-query only the rows of a previous run whose ``new_res`` is ``"error"``.

    Args:
        args: parsed command-line options (see ``set_args``).
        agent: object whose ``call(request_data)`` returns the reply string.
        exams: a batch of a previous output file; it must already contain a
            ``new_res`` column alongside ``args.in_variables``.

    Returns:
        ``exams`` with ``new_res`` updated. Successful results are kept, and
        failed ones are requested again; rows that still cannot be parsed stay
        ``"error"``.
    """
    new_res = []
    for prev_res, param in zip(exams["new_res"], exams[args.in_variables]):
        if prev_res != "error":  # already succeeded, no need to request again
            new_res.append(prev_res)
            continue

        request_data = _build_request_data(args, param)
        new_res.append("error" if request_data is None else agent.call(request_data))

    exams["new_res"] = new_res
    return exams


def main():
    """Script entry point: request the bot for every input row and save the replies.

    With ``--breakpoint``, the input file is a previous output (it carries a
    ``new_res`` column), and only rows whose result is ``"error"`` are
    requested again.
    """
    cli = set_args()
    bot = IdleAiAgent(cli.bot_id, cli.online)
    records = datasets.load_dataset("json", data_files=cli.input_file)["train"]

    worker = request_checkpoint_fun if cli.breakpoint else request_fun
    records = records.map(
        partial(worker, cli, bot),
        batched=True,
        num_proc=cli.num_proc,
        load_from_cache_file=False,  # recompute rather than reuse a cached map result
    )

    records.to_json(cli.output_file, orient="records", lines=True, force_ascii=False)


if __name__ == "__main__":
    main()